skip_if_not_m1_mac()

test_that("can create tensors on the MPS device", {
  x <- torch_randn(10, 10, device = "mps")
  expect_tensor(x)
  expect_true(x$device == torch_device("mps", 0))
  
  y <- torch_mm(x, x)
  expect_true(y$device == torch_device("mps", 0))
})

test_that("can allocate a bunch of tensors without OOM", {
  expect_no_error({
    for(i in 1:25) x <- torch_randn(10000, 10000, device="mps")  
  })
})

test_that("can run nn_linear on mps device", {
  skip_if_not_m1_mac()
  
  linear <- nn_linear(10, 1)
  linear$to(device="mps")
  
  x <- torch_randn(10, 10, device='mps')
  
  y <- linear(x)
  expect_tensor(y)
})
